import pickle
import numpy as np

import torch

from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
from rdkit import RDLogger

from qm9.data.prepare import prepare_dataset

# Disable every RDKit logger matching the 'rdApp.*' pattern.
RDLogger.DisableLog('rdApp.*')

# Integers 1..118 (presumably atomic numbers -- TODO confirm intended use).
ATOM_LIST = list(range(1, 119))
# RDKit chirality tags in a fixed order.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
# RDKit bond types in a fixed order.
BOND_LIST = [BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC]
# RDKit bond direction flags in a fixed order.
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]

# Relative output locations used when the pickles are written.
split_dir = 'Split_QM9'
smiles_dir = 'Split_QM9/SMILES'
# NOTE(review): ring_dir is not referenced by the code visible in this module;
# split_by_rings writes under f'{split_dir}/rings' (plural) instead -- confirm
# which spelling is intended.
ring_dir = 'Split_QM9/ring'

def save_dict_to_pickle(dictionary, filename):
    """Pickle ``dictionary`` into the file at ``filename``.

    The target is opened in binary write mode, so an existing file with the
    same name is overwritten.

    Args:
        dictionary: Object to serialize (any picklable object works).
        filename: Path of the file to write.
    """
    with open(filename, 'wb') as sink:
        pickle.dump(dictionary, sink)


def combine_dicts(dict1, dict2, dict3):
    """Concatenate three tensor dictionaries key by key.

    Only the keys of ``dict1`` are considered; each of them must also be
    present in ``dict2`` and ``dict3``.

    Args:
        dict1: Mapping from key to tensor; defines the keys of the result.
        dict2: Mapping providing the second chunk for every key.
        dict3: Mapping providing the third chunk for every key.

    Returns:
        dict: For every key of ``dict1``, the tensors of the three inputs
        concatenated along dimension 0 in the order dict1, dict2, dict3.
    """
    return {
        name: torch.cat((dict1[name], dict2[name], dict3[name]), dim=0)
        for name in dict1
    }


def get_qm_full(datadir, dataset):
    """Load the prepared QM9 splits and merge them into one dictionary.

    Args:
        datadir: Directory handed to ``prepare_dataset``.
        dataset: Currently unused. The dataset name given to
            ``prepare_dataset`` is hard-coded to ``'qm9'``; the parameter is
            kept so existing callers keep working.

    Returns:
        dict: For every key of the 'train' split, a tensor holding the
        'train', 'test' and 'valid' arrays concatenated along dimension 0,
        in that order.
    """
    subset = None
    splits = None
    datafiles = prepare_dataset(
        datadir, 'qm9', subset, splits, force_download=False)

    # Load every split file and convert its arrays to torch tensors.
    datasets = {}
    for split, datafile in datafiles.items():
        with np.load(datafile) as f:
            datasets[split] = {
                key: torch.from_numpy(val) for key, val in f.items()}

    # Merge order (train, test, valid) determines the row order of the result.
    return combine_dicts(
        datasets['train'], datasets['test'], datasets['valid'])


def get_smiles(full_dict, xyz_dir):
    """Read one SMILES string per entry of ``full_dict['index']``.

    For every index the file ``xyz_dir + '%06d.xyz' % index`` is opened. Its
    first line is parsed as the atom count ``n``; the first
    whitespace-separated token of the line at 0-based position ``n + 3`` is
    taken as the SMILES string.

    Args:
        full_dict: Mapping whose ``'index'`` entry is iterated; every element
            must support ``.item()``.
        xyz_dir: String prefix prepended to the zero-padded file name (it is
            concatenated as-is, no path separator is inserted).

    Returns:
        list[str]: The SMILES strings, in the order of ``full_dict['index']``.
    """
    collected = []
    for mol_id in full_dict['index']:
        path = xyz_dir + '%06d.xyz' % (mol_id.item())
        with open(path, 'r') as handle:
            content = list(handle)
        atom_count = int(content[0])
        smiles_row = content[atom_count + 3]
        collected.append(smiles_row.split()[0])
    return collected


# Scaffold Analysis
def _generate_scaffold(smiles, include_chirality=False):
    """Return the Murcko scaffold SMILES for the molecule parsed from ``smiles``.

    Args:
        smiles: SMILES string handed to ``Chem.MolFromSmiles``.
        include_chirality: Forwarded as ``includeChirality`` to
            ``MurckoScaffoldSmiles``.

    Returns:
        The value returned by ``MurckoScaffoldSmiles``.
    """
    return MurckoScaffoldSmiles(
        mol=Chem.MolFromSmiles(smiles),
        includeChirality=include_chirality,
    )


def generate_scaffolds(dataset, log_every_n=1000):
    """Group the positions of ``dataset`` by the scaffold of each SMILES.

    Args:
        dataset: Sized iterable of SMILES strings.
        log_every_n: A progress line is printed whenever the current position
            is a multiple of this value.

    Returns:
        tuple: ``(scaffold_sets, scaffolds)`` where ``scaffolds`` maps each
        scaffold string to the sorted list of positions that produced it, and
        ``scaffold_sets`` is the list of those position lists ordered by
        descending ``(group size, first position)``.
    """
    total = len(dataset)
    print(total)

    print("About to generate scaffolds")
    grouped = {}
    for position, smiles in enumerate(dataset):
        if position % log_every_n == 0:
            print("Generating scaffold %d/%d" % (position, total))
        grouped.setdefault(_generate_scaffold(smiles), []).append(position)

    scaffolds = {name: sorted(members) for name, members in grouped.items()}

    # Largest groups first; ties are resolved by the first member position.
    ordered = sorted(
        scaffolds.items(),
        key=lambda item: (len(item[1]), item[1][0]),
        reverse=True,
    )
    scaffold_sets = [members for _, members in ordered]
    return scaffold_sets, scaffolds


def extract_values_by_index(dictionary, index_list):
    """Index every value of ``dictionary`` with ``index_list``.

    Args:
        dictionary: Mapping whose values support indexing with ``index_list``
            (the callers in this module pass tensors).
        index_list: Index object applied to every value.

    Returns:
        dict: Same keys as ``dictionary``; each value is
        ``value[index_list]``.
    """
    return {name: column[index_list] for name, column in dictionary.items()}


def split_full_dataset(full_dict, scaffold_smiles_list, scaffolds=None,
                       train_limit=100000, test_limit=15000):
    """Greedily assign whole scaffold groups to three classes and slice the data.

    Scaffolds are visited in the order of ``scaffold_smiles_list``. A group
    goes to the first class when it still fits under ``train_limit``,
    otherwise to the second class when it fits under ``test_limit``, and to
    the third class in all remaining cases. A group is never split.

    Args:
        full_dict: Mapping from key to tensor covering the full dataset.
        scaffold_smiles_list: Scaffold strings in the order they are visited.
        scaffolds: Mapping from scaffold string to the list of dataset
            positions with that scaffold. When omitted, the module-level
            name ``scaffolds`` is used, which is what this function read
            before the parameter existed.
        train_limit: Maximum number of entries in the first class.
        test_limit: Maximum number of entries in the second class.

    Returns:
        dict: ``{'ClassI': ..., 'ClassII': ..., 'ClassIII': ...}`` where each
        value is ``full_dict`` restricted to the positions of that class.

    Raises:
        NameError: If ``scaffolds`` is omitted and no module-level
            ``scaffolds`` exists.
    """
    if scaffolds is None:
        # Backward compatibility: earlier versions read the mapping from a
        # module global that only exists when the file runs as a script.
        try:
            scaffolds = globals()['scaffolds']
        except KeyError:
            raise NameError(
                "split_full_dataset needs the scaffold-to-indices mapping: "
                "pass `scaffolds=` or define a module-level `scaffolds`"
            ) from None

    split_train_index = []
    split_test_index = []
    split_valid_index = []

    for scaffold_smile in scaffold_smiles_list:
        index_list = scaffolds[scaffold_smile]
        scaffold_len = len(index_list)
        if len(split_train_index) + scaffold_len <= train_limit:
            split_train_index.extend(index_list)
        elif len(split_test_index) + scaffold_len <= test_limit:
            split_test_index.extend(index_list)
        else:
            split_valid_index.extend(index_list)

    return {
        'ClassI': extract_values_by_index(full_dict, split_train_index),
        'ClassII': extract_values_by_index(full_dict, split_test_index),
        'ClassIII': extract_values_by_index(full_dict, split_valid_index),
    }

def generate_scaffolds_dict(xyz_dir, text_list, splited_scaffold_datasets):
    """Compute and pickle the scaffold grouping of every named subset.

    For each name in ``text_list`` the SMILES of that subset are read, grouped
    by scaffold, and the resulting scaffold-to-positions mapping is written to
    ``{split_dir}/scaffolds_dict_{name}.pickle``.

    Args:
        xyz_dir: Prefix forwarded to ``get_smiles``.
        text_list: Iterable of subset names (keys of
            ``splited_scaffold_datasets``).
        splited_scaffold_datasets: Mapping from subset name to its data
            dictionary.
    """
    for text in text_list:
        subset_smiles = get_smiles(splited_scaffold_datasets[text], xyz_dir)
        print('Analyze Scaffolds')
        # Only the mapping is stored; the ordered index sets are not needed.
        _, scaffold_groups = generate_scaffolds(subset_smiles, 10000)
        target = f'{split_dir}/scaffolds_dict_{text}.pickle'
        save_dict_to_pickle(scaffold_groups, target)


# Add mask
def get_scaffold_mask(text, dataset, smiles_list, scaffolds_list):
    """Build a mask marking the atoms that belong to each molecule's scaffold.

    Args:
        text: Unused; kept so existing callers keep working.
        dataset: Mapping with a ``'charges'`` tensor holding one row per
            molecule. The code compares its entries against RDKit atomic
            numbers, skips entries equal to 1 and stops at the first 0.
        smiles_list: One SMILES string per row of ``dataset['charges']``.
        scaffolds_list: One scaffold string per row of ``dataset['charges']``.

    Returns:
        torch.Tensor: Tensor shaped like ``dataset['charges']`` with 1 at the
        positions matched to a scaffold atom and 0 elsewhere. Rows whose
        scaffold has no substructure match stay all zero.
    """
    default_mask = torch.zeros_like(dataset['charges'])
    rows = zip(dataset['charges'], smiles_list, scaffolds_list)
    for molecule_index, (charges, smiles, scaffold) in enumerate(rows):
        molecule = Chem.MolFromSmiles(smiles)
        scaffold_mol = Chem.MolFromSmiles(scaffold)

        # Locate the scaffold inside the molecule. If the search raises
        # (presumably because the scaffold could not be parsed as SMILES --
        # TODO confirm), retry with the scaffold parsed as SMARTS.
        try:
            scaffold_indices = molecule.GetSubstructMatches(scaffold_mol)
        except Exception:
            scaffold_mol = Chem.MolFromSmarts(scaffold)
            scaffold_indices = molecule.GetSubstructMatches(scaffold_mol)
        if len(scaffold_indices) == 0:
            continue

        atoms_num_list = [atom.GetAtomicNum() for atom in molecule.GetAtoms()]

        # Only the first substructure match is used.
        scaffold_atoms = set(scaffold_indices[0])
        in_scaffold = [i in scaffold_atoms for i in range(len(atoms_num_list))]

        # Walk the charges row and the RDKit atom list together. Entries equal
        # to 1 are skipped (presumably hydrogens, which the SMILES-derived
        # molecule does not list). An RDKit atom is consumed each time its
        # atomic number equals the current entry.
        smiles_atom_index = 0
        for charge_index in range(len(charges)):
            if charges[charge_index] != 1:
                if atoms_num_list[smiles_atom_index] == charges[charge_index]:
                    if in_scaffold[smiles_atom_index]:
                        default_mask[molecule_index][charge_index] = 1
                    smiles_atom_index += 1
            # Stop at the first 0 entry or once every RDKit atom is matched.
            if charges[charge_index] == 0 or smiles_atom_index == len(atoms_num_list):
                break
    return default_mask

def add_scaffold_mask_and_save(text_list, splited_scaffold_datasets, xyz_dir=None):
    """Attach a scaffold mask to every named subset and pickle the results.

    For each name in ``text_list`` the subset dictionary is modified in place
    (a ``'scaffold_mask'`` entry is added) and written to
    ``{split_dir}/QM9_scaffold_split_{name}_dataset.pickle``. The subset's
    SMILES list is written to
    ``{smiles_dir}/data_scaffold_{name}_smiles.pickle``.

    Args:
        text_list: Iterable of subset names (keys of
            ``splited_scaffold_datasets``).
        splited_scaffold_datasets: Mapping from subset name to its data
            dictionary.
        xyz_dir: Prefix forwarded to ``get_smiles``. When omitted, the
            module-level name ``xyz_dir`` is used, which is what this
            function read before the parameter existed.

    Raises:
        NameError: If ``xyz_dir`` is omitted and no module-level ``xyz_dir``
            exists.
    """
    if xyz_dir is None:
        # Backward compatibility: the prefix used to be read from a module
        # global that only exists when the file runs as a script.
        try:
            xyz_dir = globals()['xyz_dir']
        except KeyError:
            raise NameError(
                "add_scaffold_mask_and_save needs the xyz file prefix: "
                "pass `xyz_dir=` or define a module-level `xyz_dir`"
            ) from None

    for text in text_list:
        subset = splited_scaffold_datasets[text]
        smiles_list = get_smiles(subset, xyz_dir)
        scaffold_smiles_list = [_generate_scaffold(smiles) for smiles in smiles_list]
        subset['scaffold_mask'] = get_scaffold_mask(
            text, subset, smiles_list, scaffold_smiles_list)
        save_dict_to_pickle(subset, f'{split_dir}/QM9_scaffold_split_{text}_dataset.pickle')
        save_dict_to_pickle(smiles_list, f'{smiles_dir}/data_scaffold_{text}_smiles.pickle')


## RING NUMBER SPLIT
def get_rings(smiles):
    """Return the ring count RDKit reports for the molecule parsed from ``smiles``.

    Args:
        smiles: SMILES string handed to ``Chem.MolFromSmiles``.

    Returns:
        The value of ``NumRings()`` on the molecule's ring info.
    """
    return Chem.MolFromSmiles(smiles).GetRingInfo().NumRings()


def generate_index_dict(lst):
    """Map every distinct item of ``lst`` to the positions where it occurs.

    Args:
        lst: Iterable of hashable items.

    Returns:
        dict: Item -> list of 0-based positions in ascending order. Keys
        appear in order of first occurrence.
    """
    positions = {}
    for position, element in enumerate(lst):
        positions.setdefault(element, []).append(position)
    return positions


# Generate a new dictionary with extracted values
def extract_values_by_index(dictionary, index_list):
    """Return a new dictionary whose values are ``value[index_list]``.

    NOTE: this re-defines a function of the same name declared earlier in the
    module; the two definitions are equivalent.

    Args:
        dictionary: Mapping whose values support indexing with ``index_list``.
        index_list: Index object applied to every value.

    Returns:
        dict: Same keys as ``dictionary``, each value indexed by
        ``index_list``.
    """
    picked = {}
    for field in dictionary:
        picked[field] = dictionary[field][index_list]
    return picked


def get_ring_mask(dataset, smiles_list, rings_list):
    """Build a mask marking the atoms that are members of at least one ring.

    Args:
        dataset: Mapping with a ``'charges'`` tensor holding one row per
            molecule. The code compares its entries against RDKit atomic
            numbers, skips entries equal to 1 and stops at the first 0.
        smiles_list: One SMILES string per row of ``dataset['charges']``.
        rings_list: For every molecule, an iterable of rings, each ring being
            an iterable of atom indices.

    Returns:
        torch.Tensor: Tensor shaped like ``dataset['charges']`` with 1 at the
        positions matched to a ring atom and 0 elsewhere. Molecules without
        rings keep an all-zero row.
    """
    default_mask = torch.zeros_like(dataset['charges'])

    # Iterate over the rows of dataset['charges']. Iterating over `dataset`
    # itself would yield the dictionary keys and leave the mask empty.
    rows = zip(dataset['charges'], smiles_list, rings_list)
    for molecule_index, (charges, smiles, molecule_rings) in enumerate(rows):
        if len(molecule_rings) == 0:
            continue

        # Union of the atom indices of all rings of this molecule.
        ring_atoms = set()
        for ring in molecule_rings:
            ring_atoms.update(ring)

        molecule = Chem.MolFromSmiles(smiles)
        atoms_num_list = [atom.GetAtomicNum() for atom in molecule.GetAtoms()]
        in_ring = [i in ring_atoms for i in range(len(atoms_num_list))]

        # Walk the charges row and the RDKit atom list together. Entries equal
        # to 1 are skipped (presumably hydrogens, which the SMILES-derived
        # molecule does not list). An RDKit atom is consumed each time its
        # atomic number equals the current entry.
        smiles_atom_index = 0
        for charge_index in range(len(charges)):
            if charges[charge_index] != 1:
                if atoms_num_list[smiles_atom_index] == charges[charge_index]:
                    if in_ring[smiles_atom_index]:
                        default_mask[molecule_index][charge_index] = 1
                    smiles_atom_index += 1
            # Stop at the first 0 entry or once every RDKit atom is matched.
            if charges[charge_index] == 0 or smiles_atom_index == len(atoms_num_list):
                break
    return default_mask


def get_molecule_rings(smiles):
    """Return the atom indices of every ring in the molecule parsed from ``smiles``.

    Args:
        smiles: SMILES string handed to ``Chem.MolFromSmiles``.

    Returns:
        The value of ``AtomRings()`` on the molecule's ring info.
    """
    return Chem.MolFromSmiles(smiles).GetRingInfo().AtomRings()


def split_and_add_ring_mask_save(smiles_list, full_dict, xyz_dir=None):
    """Split the data by ring count, add ring masks and pickle each part.

    Molecules with 0-3 rings are divided per ring count into the first 90 %
    (``'03_train'``) and the remainder (``'03_valid'``). Molecules with 4-8
    rings form ``'48_test'``. Molecules with any other ring count are not
    written anywhere. Each part gets a ``'ring_mask'`` entry and is written
    to ``{split_dir}/QM9_ring_{name}_dataset.pickle``.

    Args:
        smiles_list: One SMILES string per row of ``full_dict``.
        full_dict: Mapping from key to tensor covering the full dataset.
        xyz_dir: Prefix forwarded to ``get_smiles``. When omitted, the
            module-level name ``xyz_dir`` is used, which is what this
            function read before the parameter existed.

    Raises:
        NameError: If ``xyz_dir`` is omitted and no module-level ``xyz_dir``
            exists.
    """
    if xyz_dir is None:
        # Backward compatibility: the prefix used to be read from a module
        # global that only exists when the file runs as a script.
        try:
            xyz_dir = globals()['xyz_dir']
        except KeyError:
            raise NameError(
                "split_and_add_ring_mask_save needs the xyz file prefix: "
                "pass `xyz_dir=` or define a module-level `xyz_dir`"
            ) from None

    ring_counts = [get_rings(s) for s in smiles_list]
    ring_index_dict = generate_index_dict(ring_counts)

    # A ring count that never occurs contributes no rows instead of raising
    # KeyError.
    split_train_index = []
    split_valid_index = []
    for key in [0, 1, 2, 3]:
        ring_index_list = ring_index_dict.get(key, [])
        size = int(len(ring_index_list) * 0.9)
        split_train_index.extend(ring_index_list[:size])
        split_valid_index.extend(ring_index_list[size:])
    split_test_index = []
    for key in [4, 5, 6, 7, 8]:
        split_test_index.extend(ring_index_dict.get(key, []))

    ring_datasets = {
        '03_train': extract_values_by_index(full_dict, split_train_index),
        '03_valid': extract_values_by_index(full_dict, split_valid_index),
        '48_test': extract_values_by_index(full_dict, split_test_index),
    }

    for text, subset in ring_datasets.items():
        # SMILES are re-read so that they line up with the rows of the subset.
        subset_smiles = get_smiles(subset, xyz_dir)
        ring_list = [get_molecule_rings(smiles) for smiles in subset_smiles]
        subset['ring_mask'] = get_ring_mask(subset, subset_smiles, ring_list)
        with open(f'{split_dir}/QM9_ring_{text}_dataset.pickle', 'wb') as file:
            pickle.dump(subset, file)


def split_by_rings(smiles_list, qm9_full_data, xyz_dir=None):
    """Write one dataset pickle and one SMILES pickle per distinct ring count.

    For every ring count ``k`` found in ``smiles_list`` the matching rows of
    ``qm9_full_data`` are extracted, a ``'ring_mask'`` entry is added, and the
    result is written to ``{split_dir}/rings/QM9_ring_number_{k}_dataset.pickle``.
    The SMILES strings of those rows are written to
    ``{smiles_dir}/data_ring_{k}_smiles.pickle``.

    Args:
        smiles_list: One SMILES string per row of ``qm9_full_data``.
        qm9_full_data: Mapping from key to tensor covering the full dataset.
        xyz_dir: Prefix forwarded to ``get_smiles``. When omitted, the
            module-level name ``xyz_dir`` is used, which is what this
            function read before the parameter existed.

    Raises:
        NameError: If ``xyz_dir`` is omitted and no module-level ``xyz_dir``
            exists.
    """
    if xyz_dir is None:
        # Backward compatibility: the prefix used to be read from a module
        # global that only exists when the file runs as a script.
        try:
            xyz_dir = globals()['xyz_dir']
        except KeyError:
            raise NameError(
                "split_by_rings needs the xyz file prefix: "
                "pass `xyz_dir=` or define a module-level `xyz_dir`"
            ) from None

    ring_counts = [get_rings(s) for s in smiles_list]
    ring_index_dict = generate_index_dict(ring_counts)
    for key, value in ring_index_dict.items():
        split_ring_dataset = extract_values_by_index(qm9_full_data, value)

        sub_smiles_list = get_smiles(split_ring_dataset, xyz_dir)
        ring_list = [get_molecule_rings(smiles) for smiles in sub_smiles_list]
        split_ring_dataset['ring_mask'] = get_ring_mask(
            split_ring_dataset, sub_smiles_list, ring_list)
        # NOTE(review): the module-level `ring_dir` constant is
        # 'Split_QM9/ring', while this path uses 'rings' -- confirm which
        # directory is intended.
        with open(f'{split_dir}/rings/QM9_ring_number_{key}_dataset.pickle', 'wb') as file:
            pickle.dump(split_ring_dataset, file)
        # Store the SMILES strings in the *_smiles.pickle file. This used to
        # dump `ring_list` (the ring atom-index tuples) under that name.
        with open(f'{smiles_dir}/data_ring_{key}_smiles.pickle', 'wb') as file:
            pickle.dump(sub_smiles_list, file)

if __name__ == '__main__':
    import os

    datadir = 'qm9/temp'
    dataset = 'qm9'
    # NOTE: `xyz_dir` and `scaffolds` (below) must remain module-level names;
    # functions in this module may look them up as globals.
    xyz_dir = datadir + '/qm9/xyz/dsgdb9nsd_'

    # Create every output folder up front so that a missing directory does
    # not abort the run after the expensive analysis steps.
    for out_dir in (split_dir, smiles_dir, f'{split_dir}/rings'):
        os.makedirs(out_dir, exist_ok=True)

    print('Download and load QM9')
    qm9_full_data = get_qm_full(datadir, dataset)
    print('Analyze SMILES')
    smiles_list = get_smiles(qm9_full_data, xyz_dir)
    print('Analyze Scaffolds')
    _, scaffolds = generate_scaffolds(smiles_list, 10000)
    scaffold_smiles_list = list(scaffolds.keys())
    print('Split QM9 according to the scaffolds, into: ClasssI, ClassII, ClassIII')
    splited_scaffold_datasets = split_full_dataset(qm9_full_data, scaffold_smiles_list)
    # A list (not a set) keeps the processing order identical between runs.
    text_list = ['ClassI', 'ClassII', 'ClassIII']
    print('Generate QM9 scaffolds dict')
    generate_scaffolds_dict(xyz_dir, text_list, splited_scaffold_datasets)
    add_scaffold_mask_and_save(text_list, splited_scaffold_datasets)
    print('Split QM9 according to the rings, into: ring0-3 as train and valid, ring4-8 as test')
    split_and_add_ring_mask_save(smiles_list, qm9_full_data)
    print('Split QM9 by rings')
    split_by_rings(smiles_list, qm9_full_data)
